#include "pa_v1.cuh"

#include <cstddef>
                        
                        
// C-linkage declaration of the generated paged-attention launcher. The symbol
// name is substituted by the template engine; the definition follows below in
// this file and must keep exactly this parameter list.
//
// How the definition uses each argument:
//   out_ptr                 cast to the output dtype and handed to the reduce
//                           kernel as its first argument
//   workspace_buffer_ptr    scratch buffer carved into three consecutive
//                           regions: exp_sums | max_logits | tmp_out
//   query_ptr               cast to the compute dtype, passed to the QKV kernel
//   key_cache_ptr,
//   value_cache_ptr         cast to the KV-cache dtype, passed to the QKV kernel
//   block_tables_ptr        passed through to the QKV kernel
//   cu_query_lens_ptr,
//   context_lens_ptr        passed through to both kernels
//   alibi_slopes_ptr,
//   q_scale_ptr, k_scale_ptr,
//   v_scale_ptr             passed through to the QKV kernel
//   fp8_out_scale_ptr       passed through to the reduce kernel
//   scale, logits_soft_cap,
//   sliding_window          scalar parameters of the QKV kernel
//   max_num_blocks_per_seq  passed through to the QKV kernel
//   max_num_partitions      y-extent of the QKV grid; also sizes the workspace
//                           regions and is passed to the reduce kernel
//   num_seqs, num_kv_heads,
//   num_heads               launch-grid extents (num_heads must be a multiple
//                           of num_kv_heads; checked with assert)
//   q_stride, kv_block_stride,
//   kv_head_stride,
//   kv_seq_stride           passed through to the QKV kernel; units are not
//                           visible here -- TODO confirm (elements vs. bytes)
//   stream                  opaque handle, reinterpreted as hipStream_t for both
//                           kernel launches
//
// NOTE(review): all pointers are forwarded to device kernels, so they are
// presumably device-accessible memory -- confirm against the caller.
extern "C" {
void {{func_name}}(void* out_ptr,
        void* workspace_buffer_ptr,
        void* query_ptr,
        void* key_cache_ptr,
        void* value_cache_ptr,
        int* block_tables_ptr,
        int* cu_query_lens_ptr,
        int* context_lens_ptr,
        const float* alibi_slopes_ptr,
        const float* q_scale_ptr,
        const float* k_scale_ptr,
        const float* v_scale_ptr,
        const float* fp8_out_scale_ptr,       
        float scale,
        const int max_num_blocks_per_seq,
        const int max_num_partitions,
        const float logits_soft_cap,
        const int num_seqs,
        const int num_kv_heads,
        const int num_heads,
        const int q_stride,
        const int kv_block_stride,
        const int kv_head_stride,
        const int kv_seq_stride,
        const int sliding_window,
        void* stream);
}
                        
// Host-side launcher for the generated paged-attention instance.
//
// Enqueues two kernels on `stream`, in order:
//   1. paged_attention_ll4mi_QKV_mfma16_kernel
//        grid  = (num_seqs, max_num_partitions, num_kv_heads), block = 256
//        receives the exp_sums / max_logits / tmp_out workspace regions
//   2. paged_attention_ll4mi_reduce_kernel
//        grid  = (num_heads, num_seqs, mtp), block = head_size
//        receives the same workspace regions plus `out_ptr`
//
// Workspace layout inside `workspace_buffer_ptr` (N = num_seqs * num_heads *
// max_num_partitions):
//   [0, N)   float      exp_sums
//   [N, 2N)  float      max_logits
//   [2N, ..) compute dtype, tmp_out (its extent is not determined here; the
//            caller must size the buffer for what the kernels access)
//
// The function only enqueues work; it performs no synchronization and reports
// no errors. Calls with a non-positive num_seqs, num_heads, num_kv_heads or
// max_num_partitions return immediately without launching anything.
void {{func_name}}(void* out_ptr,
        void* workspace_buffer_ptr,
        void* query_ptr,
        void* key_cache_ptr,
        void* value_cache_ptr,
        int* block_tables_ptr,
        int* cu_query_lens_ptr,
        int* context_lens_ptr,
        const float* alibi_slopes_ptr,
        const float* q_scale_ptr,
        const float* k_scale_ptr,
        const float* v_scale_ptr,
        const float* fp8_out_scale_ptr,
        float scale,
        const int max_num_blocks_per_seq,
        const int max_num_partitions,
        const float logits_soft_cap,
        const int num_seqs,
        const int num_kv_heads,
        const int num_heads,
        const int q_stride,
        const int kv_block_stride,
        const int kv_head_stride,
        const int kv_seq_stride,
        const int sliding_window,
        void* stream)
{
    constexpr int head_size = {{head_size}};
    constexpr int PARTITION_SIZE = {{partition_size}};
    constexpr int gqa_ratio = {{gqa_ratio}};

    // Empty problem: a zero (or negative, which would wrap when converted to
    // dim3's unsigned fields) grid extent is an invalid launch configuration,
    // and num_kv_heads == 0 would divide by zero in the check below.
    if(num_seqs <= 0 || num_heads <= 0 || num_kv_heads <= 0 || max_num_partitions <= 0)
    {
        return;
    }
    assert(num_heads % num_kv_heads == 0);

    {% if logits_soft_cap_enabled %}
    // NOTE(review): logits_soft_cap == 0 yields an infinite reciprocal here;
    // callers are presumably expected to pass a non-zero cap when this variant
    // is generated -- confirm.
    const float logits_soft_cap_rcp = 1.f / logits_soft_cap;
    {% else %}
    const float logits_soft_cap_rcp = 0.f;
    {% endif %}

    // Number of per-(sequence, head, partition) partial results. Computed in
    // size_t so the product cannot overflow 32-bit int before it is used as a
    // pointer offset.
    const std::size_t num_partials = static_cast<std::size_t>(num_seqs) *
                                     static_cast<std::size_t>(num_heads) *
                                     static_cast<std::size_t>(max_num_partitions);

    // Carve the workspace into its three consecutive regions.
    float* exp_sums_ptr   = reinterpret_cast<float*>(workspace_buffer_ptr);
    float* max_logits_ptr = exp_sums_ptr + num_partials;
    {{dtype}}* tmp_out_ptr = reinterpret_cast<{{dtype}}*>(max_logits_ptr + num_partials);

    // NOTE(review): `variant` lives on the host stack and its address is passed
    // to the kernel below. That is only safe if the kernel uses the pointer for
    // its type alone and never dereferences it -- confirm in pa_v1.cuh.
    ck_tile::ComposedAttention<{{"true" if logits_soft_cap_enabled else "false"}} * ck_tile::LOGITS_SOFT_CAP> variant;

    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

    // Stage 1: per-partition attention.
    constexpr int NTHR = 256;
    dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
    dim3 block(NTHR);

    paged_attention_ll4mi_QKV_mfma16_kernel<{{dtype}},
                                            {{kv_dtype}},
                                            {% if fp8_kv_dtype == 'auto' %}
                                            vllm::Fp8KVCacheDataType::kAuto,
                                            {% else %}
                                            vllm::Fp8KVCacheDataType::kFp8E4M3,
                                            {% endif %}
                                            {{block_size}},
                                            head_size,
                                            NTHR,
                                            {{"true" if alibi_enabled else "false"}},
                                            gqa_ratio,
                                            {{mtp}},
                                            decltype(variant),
                                            {{"true" if sliding_window_enabled else "false"}}>
    <<<grid, block, 0, hip_stream>>>(reinterpret_cast<{{dtype}}*>(query_ptr),
                                     reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr),
                                     reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr),
                                     scale,
                                     block_tables_ptr,
                                     cu_query_lens_ptr,
                                     context_lens_ptr,
                                     max_num_blocks_per_seq,
                                     alibi_slopes_ptr,
                                     q_stride,
                                     kv_block_stride,
                                     kv_head_stride,
                                     kv_seq_stride,
                                     exp_sums_ptr,
                                     max_logits_ptr,
                                     tmp_out_ptr,
                                     logits_soft_cap,
                                     logits_soft_cap_rcp,
                                     q_scale_ptr,
                                     k_scale_ptr,
                                     v_scale_ptr,
                                     &variant,
                                     sliding_window);

    // Stage 2: combine the per-partition partial results into the final output.
    // Enqueued on the same stream, so it runs after stage 1.
    dim3 reduce_grid(num_heads, num_seqs, {{mtp}});
    dim3 reduce_block(head_size);
    paged_attention_ll4mi_reduce_kernel<{{dtype}}, {{out_dtype}}, head_size, head_size, PARTITION_SIZE, {{npar_loops}}>
    <<<reduce_grid, reduce_block, 0, hip_stream>>>(reinterpret_cast<{{out_dtype}}*>(out_ptr),
                                                   exp_sums_ptr,
                                                   max_logits_ptr,
                                                   tmp_out_ptr,
                                                   cu_query_lens_ptr,
                                                   context_lens_ptr,
                                                   max_num_partitions,
                                                   fp8_out_scale_ptr);
}